package org.bdware.sc.visitor;

import org.antlr.v4.runtime.misc.Interval;
import org.antlr.v4.runtime.tree.TerminalNodeImpl;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.bdware.sc.node.AnnotationNode;
import org.bdware.sc.node.FunctionNode;
import org.bdware.sc.node.Op;
import org.bdware.sc.node.StmtNode;
import org.bdware.sc.node.stmt.*;
import org.bdware.sc.parser.YJSParser;
import org.bdware.sc.parser.YJSParser.*;
import org.bdware.sc.parser.YJSParserBaseVisitor;

import java.util.ArrayList;
import java.util.List;
import java.util.Stack;

/**
 * Lowers one YJS function (a {@code function} declaration or a method definition) into the flat
 * statement list of a {@link FunctionNode}.
 *
 * <p>Structured control flow is flattened into {@link LabelStmt} markers joined by
 * {@link GotoStmt} and {@link BranchStmt} jumps. Expressions themselves are not lowered yet: each
 * single expression becomes one {@code Op.STUB} statement naming a fresh virtual register
 * {@code "Reg<n>"}, and the statement consuming the expression (branch, return, assignment) reads
 * the register allocated last.
 *
 * <p>All per-function state lives in mutable fields that are reset at the start of every function,
 * so one instance can read several functions one after another.
 */
public class FunctionReader extends YJSParserBaseVisitor<FunctionNode> {
    private static final Logger LOGGER = LogManager.getLogger(FunctionReader.class);
    /** Prefix of the virtual register names allocated for expression results. */
    private static final String REG_PREFIX = "Reg";

    /** Function currently being built; replaced at the start of every function. */
    FunctionNode node;
    /** Intervals that {@code break}/{@code continue} resolve against; innermost on top. */
    Stack<BlockInterval> blockStack;
    /** Source file name recorded on every produced {@link FunctionNode}. */
    String fileName;
    /** Number of virtual registers allocated so far in the current function. */
    int regID;

    /**
     * @param fileName source file name stored on every {@link FunctionNode} this reader creates
     */
    public FunctionReader(String fileName) {
        this.fileName = fileName;
    }

    /** Reads a method definition; its name is the text of the property name. */
    @Override
    public FunctionNode visitMethodDefinition(YJSParser.MethodDefinitionContext ctx) {
        beginFunction(ctx.propertyName().identifierName().getText());
        initParams(ctx.formalParameterList());
        initFunctionBody(ctx.functionBody());
        node.setLine(ctx.start.getLine());
        node.setPos(ctx.start.getCharPositionInLine());
        node.setInterval(ctx.getSourceInterval());
        return node;
    }

    /**
     * Reads a {@code function} declaration together with its export/view flags and annotations.
     * A {@code Mask} annotation additionally marks the function as masked.
     */
    @Override
    public FunctionNode visitFunctionDeclaration(YJSParser.FunctionDeclarationContext ctx) {
        beginFunction(ctx.Identifier().toString());
        node.setIsExport(null != ctx.Export());
        node.setView(null != ctx.View());
        initParams(ctx.formalParameterList());
        initFunctionBody(ctx.functionBody());
        node.setLine(ctx.start.getLine());
        node.setPos(ctx.start.getCharPositionInLine());

        // The recorded interval starts at the 'function' keyword rather than at the start of the
        // whole declaration context, so tokens preceding that keyword are excluded.
        node.setInterval(
                new Interval(ctx.Function().getSourceInterval().a, ctx.getSourceInterval().b));

        if (null != ctx.annotations()) {
            for (AnnotationContext annotation : ctx.annotations().annotation()) {
                AnnotationNode annNode = readAnnotation(annotation);
                if (annNode.getType().equals("Mask")) {
                    node.setIsMask(true);
                }
                node.addAnnotation(annNode);
            }
        }
        return node;
    }

    /** Resets all per-function state and starts a new {@link FunctionNode} called {@code name}. */
    private void beginFunction(String name) {
        node = new FunctionNode(name, fileName);
        blockStack = new Stack<>();
        regID = 0;
    }

    /**
     * Converts one annotation into an {@link AnnotationNode}, copying the source text of each
     * literal argument (numeric, string, or object literal) in order.
     */
    private static AnnotationNode readAnnotation(AnnotationContext annotation) {
        AnnotationNode annNode = new AnnotationNode(annotation.Identifier().toString());
        if (null == annotation.annotationArgs()) {
            return annNode;
        }
        for (AnnotationLiteralContext literal : annotation.annotationArgs().annotationLiteral()) {
            String text;
            if (null != literal.numericLiteral()) {
                text = literal.numericLiteral().getText();
                LOGGER.debug("------AnnotationNumericArgs:{}", text);
            } else if (null != literal.StringLiteral()) {
                text = literal.StringLiteral().getText();
                LOGGER.debug("------AnnotationStringArgs:{}", text);
            } else {
                text = literal.objectLiteral().getText();
                LOGGER.debug("------AnnotationObjectArgs:{}", text);
            }
            annNode.addArg(text);
        }
        return annNode;
    }

    /** Lowers every statement of the function body in source order; a missing body is a no-op. */
    private void initFunctionBody(FunctionBodyContext functionBody) {
        if (null == functionBody || null == functionBody.sourceElements()) {
            return;
        }
        for (SourceElementContext element : functionBody.sourceElements().sourceElement()) {
            StatementContext stmt = element.statement();
            // The default visitStatement dereferences its argument, so skip elements without one.
            if (null != stmt) {
                visitStatement(stmt);
            }
        }
    }

    /** Registers each formal parameter name, in declaration order, as an argument of the node. */
    private void initParams(FormalParameterListContext argList) {
        if (null == argList) {
            return;
        }
        for (FormalParameterArgContext arg : argList.formalParameterArg()) {
            node.addArg(arg.Identifier().toString());
        }
    }

    /**
     * Emits a block as its statements surrounded by start/end labels.
     *
     * <p>NOTE(review): the block is pushed onto {@code blockStack}, so a {@code break}/{@code
     * continue} inside a braced loop body resolves to this block's labels rather than the loop's.
     * Loops other than {@code for} are not lowered here and currently rely on this push, so it is
     * kept; confirm the intended semantics before changing it.
     */
    public FunctionNode visitBlock(YJSParser.BlockContext ctx) {
        LabelStmt start = new LabelStmt();
        start.setLineAndPos(ctx.start);
        LabelStmt end = new LabelStmt();
        end.setLineAndPos(ctx.stop);
        node.addStmt(start);
        blockStack.push(new BlockInterval(start, end));
        if (null != ctx.statementList()) {
            ctx.statementList().accept(this);
        }
        node.addStmt(end);
        blockStack.pop();
        return node;
    }

    // ========Now we can handle the stmts========

    /** Lowers each declaration of a {@code var} statement in order. */
    public FunctionNode visitVariableStatement(YJSParser.VariableStatementContext ctx) {
        for (VariableDeclarationContext varCtx :
                ctx.variableDeclarationList().variableDeclaration()) {
            varCtx.accept(this);
        }
        return node;
    }

    /**
     * Emits {@code Move} into the declared identifier. The source is the register holding the
     * initializer, or the literal name {@code "undefined"} when there is no initializer.
     */
    public FunctionNode visitVariableDeclaration(YJSParser.VariableDeclarationContext ctx) {
        Stmt2N move = new Stmt2N(Op.Move);
        move.setLineAndPos(ctx.start);
        move.setTo(ctx.Identifier().toString());
        if (null != ctx.singleExpression()) {
            node.addStmts(parseSingleExpression(ctx.singleExpression()));
            move.setFrom(lastReg());
        } else {
            move.setFrom("undefined");
        }
        // Emitted in both cases; previously the move was dropped when an initializer was present.
        node.addStmt(move);
        return node;
    }

    /** An empty statement emits nothing. */
    public FunctionNode visitEmptyStatement(YJSParser.EmptyStatementContext ctx) {
        return node;
    }

    /** Lowers every expression of the statement's comma-separated sequence. */
    public FunctionNode visitExpressionStatement(YJSParser.ExpressionStatementContext ctx) {
        node.addStmts(parseExpressionSequence(ctx.expressionSequence()));
        return node;
    }

    /**
     * Lowers {@code if (cond) then [else otherwise]} to:
     *
     * <pre>
     *   cond; branch lastReg -> elseLabel (or end); start; then;
     *   [goto end; elseLabel; otherwise;] end
     * </pre>
     *
     * The branch is placed so that taking it skips the then-part, consistent with how
     * {@code visitForStatement} branches out of the loop. NOTE(review): this assumes
     * {@link BranchStmt} jumps when the register is falsy; confirm against its consumer.
     */
    public FunctionNode visitIfStatement(YJSParser.IfStatementContext ctx) {
        List<StatementContext> branches = ctx.statement();
        boolean hasElse = null != ctx.Else();
        LabelStmt start = new LabelStmt();
        LabelStmt end = new LabelStmt();
        LabelStmt elseStart = hasElse ? new LabelStmt() : end;

        node.addStmts(parseExpressionSequence(ctx.expressionSequence()));
        node.addStmt(
                new BranchStmt().setReg(lastReg()).setTarget(elseStart).setLineAndPos(ctx.start));
        node.addStmt(start);
        visitStatement(branches.get(0));
        if (hasElse) {
            StatementContext otherwise = branches.get(1);
            // Skip the else-part after the then-part has run.
            node.addStmt(new GotoStmt().setTarget(end).setLineAndPos(otherwise.start));
            elseStart.setLineAndPos(otherwise.start);
            node.addStmt(elseStart);
            visitStatement(otherwise);
        }
        node.addStmt(end);
        return node;
    }

    /**
     * Lowers {@code for (init; cond; update) body} to:
     *
     * <pre>
     *   init; start; goto target; update; target; cond; branch lastReg -> end;
     *   body; goto start; end
     * </pre>
     *
     * The first iteration jumps over the update. {@code continue} targets {@code start}, so the
     * update runs before the condition is re-checked. Any of the three header parts may be absent.
     */
    public FunctionNode visitForStatement(YJSParser.ForStatementContext ctx) {
        int order = 2; // child index right after "for" "("
        int index = 0; // index into ctx.expressionSequence()
        ExpressionSequenceContext prePart = null;
        ExpressionSequenceContext ifPart = null;
        ExpressionSequenceContext tailPart = null;
        if (isSymbol(ctx, order, ";")) {
            order++;
        } else {
            prePart = ctx.expressionSequence(index++);
            order += 2;
        }
        if (isSymbol(ctx, order, ";")) {
            order++;
        } else {
            ifPart = ctx.expressionSequence(index++);
            order += 2;
        }
        if (!isSymbol(ctx, order, ")")) {
            tailPart = ctx.expressionSequence(index);
        }

        LabelStmt start = new LabelStmt();
        start.setLineAndPos(ctx.start);
        LabelStmt end = new LabelStmt();
        end.setLineAndPos(ctx.stop);
        LabelStmt target = new LabelStmt();
        target.setLineAndPos(ctx.start);

        // Header parts go through the same expression lowering as other statements, so the
        // condition actually allocates the register the branch below reads.
        node.addStmts(parseExpressionSequence(prePart));
        node.addStmt(start);
        node.addStmt(new GotoStmt().setTarget(target).setLineAndPos(ctx.start));
        node.addStmts(parseExpressionSequence(tailPart));
        node.addStmt(target);
        if (null != ifPart) {
            node.addStmts(parseExpressionSequence(ifPart));
            node.addStmt(
                    new BranchStmt().setReg(lastReg()).setTarget(end).setLineAndPos(ctx.start));
        }
        blockStack.push(new BlockInterval(start, end));
        visitStatement(ctx.statement());
        blockStack.pop();
        node.addStmt(new GotoStmt().setTarget(start).setLineAndPos(ctx.start));
        node.addStmt(end);
        return node;
    }

    /** Whether child {@code index} of {@code ctx} exists and is the terminal token {@code symbol}. */
    private static boolean isSymbol(YJSParser.ForStatementContext ctx, int index, String symbol) {
        return ctx.getChild(index) instanceof TerminalNodeImpl
                && ctx.getChild(index).toString().equals(symbol);
    }

    /** Emits a jump to the start label of the innermost interval on {@code blockStack}. */
    public FunctionNode visitContinueStatement(YJSParser.ContinueStatementContext ctx) {
        GotoStmt stmt = new GotoStmt();
        stmt.setLineAndPos(ctx.start);
        stmt.setTarget(blockStack.peek().start);
        node.addStmt(stmt);
        return node;
    }

    /** Emits a jump to the end label of the innermost interval on {@code blockStack}. */
    public FunctionNode visitBreakStatement(YJSParser.BreakStatementContext ctx) {
        GotoStmt stmt = new GotoStmt();
        stmt.setLineAndPos(ctx.start);
        stmt.setTarget(blockStack.peek().end);
        node.addStmt(stmt);
        return node;
    }

    /**
     * Emits {@code Return} reading the register of the returned expression, or with a
     * {@code null} operand for a bare {@code return;}.
     */
    public FunctionNode visitReturnStatement(YJSParser.ReturnStatementContext ctx) {
        ExpressionSequenceContext value = ctx.expressionSequence();
        if (null == value) {
            node.addStmt(new Stmt1N(Op.Return, null).setLineAndPos(ctx.start));
        } else {
            node.addStmts(parseExpressionSequence(value));
            node.addStmt(new Stmt1N(Op.Return, lastReg()).setLineAndPos(ctx.start));
        }
        return node;
    }

    /** Switch statements are not lowered yet: nothing is emitted and {@code null} is returned. */
    public FunctionNode visitSwitchStatement(YJSParser.SwitchStatementContext ctx) {
        return null;
    }

    /**
     * Lowers each expression of a comma-separated sequence in order. The value of the sequence ends
     * up in {@link #lastReg()}. A {@code null} sequence yields an empty list.
     */
    private List<StmtNode> parseExpressionSequence(ExpressionSequenceContext expressionSequence) {
        List<StmtNode> ret = new ArrayList<>();
        if (null == expressionSequence) {
            return ret;
        }
        for (SingleExpressionContext singleExpression : expressionSequence.singleExpression()) {
            ret.addAll(parseSingleExpression(singleExpression));
        }
        return ret;
    }

    /**
     * Placeholder expression lowering: emits one {@code Op.STUB} statement naming a freshly
     * allocated register, positioned at the expression's first token.
     */
    private List<StmtNode> parseSingleExpression(SingleExpressionContext singleExpression) {
        List<StmtNode> ret = new ArrayList<>();
        ret.add(new Stmt1N(Op.STUB, newReg()).setLineAndPos(singleExpression.start));
        return ret;
    }

    /** Allocates and returns the name of a new virtual register. */
    private String newReg() {
        return REG_PREFIX + (regID++);
    }

    /** Name of the most recently allocated register, i.e. the last lowered expression's result. */
    private String lastReg() {
        return REG_PREFIX + (regID - 1);
    }

    /** Start/end labels of a region that {@code continue}/{@code break} can jump to. */
    static class BlockInterval {
        final LabelStmt start;
        final LabelStmt end;

        public BlockInterval(LabelStmt s, LabelStmt e) {
            start = s;
            end = e;
        }
    }
}
